from typing import Tuple

import numpy as np
import torch
import torch.utils.data as data
import sys
import random
from torch.utils.data.sampler import Sampler
from torch.utils.data import DataLoader

sys.path.append('./')
from auxiliary.utils import hwc_chw, gamma_correct, brg_to_rgb
from classes.data.DataAugmenter import DataAugmenter

# Device indices: 1 = Mate30, 2 = P30 Pro, 3 = iPhone, 4 = vivo, 5 = Xiaomi 11, 6 = Xiaomi 13, 7 = all devices

class TestCTADataset(data.Dataset):
    """CTA-Set test split for a single capture device.

    Every frame of every test sequence becomes one sample. A sample is a
    three-frame window (previous, current, next) around that frame; at a
    sequence boundary the missing neighbour is replaced by the current frame,
    so the window never references a frame outside ``1..num``.

    Entries are stored as ``"<sequence_dir>,<frame_id>"`` strings. Each
    sequence directory is expected to contain ``<frame_id>.dng.npy`` frames and
    an ``illu_mat.npy`` dict mapping ``str(frame_id)`` to the illuminant.
    """

    # 1-based ``device`` index -> tag used in the split file name ``test_<tag>.npy``.
    _DEVICE_TAGS = ['mate30', 'P30pro', 'iphonepm', 'vivo', 'xiaomi11pro', 'xiaomi13']

    def __init__(self, input_size: Tuple = (224, 224), device: int = 1,
                 path_to_dataset: str = '/home/wsy/examples/',
                 split_dir: str = '/home/wsy/TAWB/dataset/CTA-Set/'):
        """Index every frame listed in one device's test split.

        :param input_size: size handed to ``DataAugmenter``.
        :param device: 1-based device index (1=mate30, 2=P30pro, 3=iphonepm,
            4=vivo, 5=xiaomi11pro, 6=xiaomi13).
        :param path_to_dataset: prefix that sequence ids are appended to in
            order to form each sequence directory.
        :param split_dir: directory containing the ``test_<tag>.npy`` files.
        :raises ValueError: if ``device`` is not in ``1..6``.
        """
        # Validate first: device=0 would otherwise silently select the last tag.
        if not 1 <= device <= len(self._DEVICE_TAGS):
            raise ValueError(f"device must be in 1..{len(self._DEVICE_TAGS)}, got {device!r}")
        self.__device = device
        self.__input_size = input_size
        self.__da = DataAugmenter(self.__input_size)
        test_path = split_dir + 'test_' + self._DEVICE_TAGS[device - 1] + '.npy'
        self._paths_to_seqs, self._nums_to_seqs = self._index_split(test_path, path_to_dataset)

    @staticmethod
    def _index_split(test_path: str, path_to_dataset: str) -> Tuple[list, list]:
        """Expand a split file into one entry per frame.

        The file is read with ``np.load(..., allow_pickle=True).item()`` and must
        provide positionally paired ``'id'`` and ``'num'`` sequences. A sequence
        ``id`` with ``num`` frames yields entries
        ``"<path_to_dataset><id>,1"`` .. ``"<path_to_dataset><id>,<num>"``.

        :returns: ``(entries, frame_counts)`` with one item per frame.
        """
        test_info = np.load(test_path, allow_pickle=True).item()
        paths, nums = [], []
        for seq_id, num in zip(test_info['id'], test_info['num']):
            for frame_id in range(1, num + 1):
                paths.append(path_to_dataset + str(seq_id) + ',' + str(frame_id))
                nums.append(num)
        return paths, nums

    @staticmethod
    def _split_entry(entry: str) -> Tuple[str, int]:
        """Split ``"<sequence_dir>,<frame_id>"`` into ``(sequence_dir, frame_id)``."""
        # rsplit so a comma inside the directory path does not break parsing
        path_to_frame, frame_id = entry.rsplit(',', 1)
        return path_to_frame, int(frame_id)

    @staticmethod
    def _window_ids(frame_id: int, num_frames: int) -> Tuple[int, int, int]:
        """Return the (previous, current, next) frame ids, clamped to ``[1, num_frames]``."""
        return max(frame_id - 1, 1), frame_id, min(frame_id + 1, num_frames)

    def __getitem__(self, index: int) -> Tuple:
        """Load the three-frame window for entry ``index``.

        :returns: ``(seq, mimic, illuminant, path_to_sequence)``.

            - ``seq`` is the stacked window after
              ``DataAugmenter.resize_sequence``, clipping to [0, 255] and scaling
              by 1/255, then ``brg_to_rgb``, ``gamma_correct`` and ``hwc_chw``.
            - ``mimic`` is ``DataAugmenter.augment_mimic`` of the raw window with
              axes permuted ``(0, 3, 1, 2)``.
            - ``illuminant`` is the current frame's entry from ``illu_mat.npy``.
            - ``path_to_sequence`` is the raw entry string.

            All arrays are returned as float32-derived torch tensors.
        """
        path_to_sequence = self._paths_to_seqs[index]
        path_to_frame, frame_id = self._split_entry(path_to_sequence)
        window = self._window_ids(frame_id, self._nums_to_seqs[index])

        illums = np.load(path_to_frame + '/illu_mat.npy', allow_pickle=True).item()
        images = [np.array(np.load(path_to_frame + '/' + str(i) + '.dng.npy'), dtype='float32')
                  for i in window]
        # Stack the frames along a new leading axis.
        seq = np.array(images, dtype='float32')
        illuminant = np.array(illums[str(frame_id)], dtype='float32')

        # mimic is derived from the raw, not-yet-resized window.
        mimic = torch.from_numpy(self.__da.augment_mimic(seq).transpose((0, 3, 1, 2)).copy())

        seq = self.__da.resize_sequence(seq)
        seq = np.clip(seq, 0.0, 255.0) * (1.0 / 255)
        seq = hwc_chw(gamma_correct(brg_to_rgb(seq)))

        seq = torch.from_numpy(seq.copy())
        illuminant = torch.from_numpy(illuminant.copy())
        return seq, mimic, illuminant, path_to_sequence

    def __len__(self) -> int:
        """Return the number of indexed frames."""
        return len(self._paths_to_seqs)


class TestAllCTADataset(data.Dataset):
    """CTA-Set test split pooled over all six capture devices.

    Every frame of every test sequence of every device becomes one sample. A
    sample is a three-frame window (previous, current, next) around that
    frame; at a sequence boundary the missing neighbour is replaced by the
    current frame, so the window never references a frame outside ``1..num``.

    Entries are stored as ``"<sequence_dir>,<frame_id>"`` strings. Each
    sequence directory is expected to contain ``<frame_id>.dng.npy`` frames and
    an ``illu_mat.npy`` dict mapping ``str(frame_id)`` to the illuminant.
    """

    # Per-device data directory name under ``dataset_root``.
    _DEVICE_DIRS = ['HuaweiMate30', 'HuaweiP30PRO', 'iphone14pm', 'vivoiqooneo5', 'Xiaomi11PRO', 'Xiaomi13']
    # Per-device tag used in the split file name ``test_<tag>.npy`` (same order as above).
    _DEVICE_TAGS = ['mate30', 'P30pro', 'iphonepm', 'vivo', 'xiaomi11pro', 'xiaomi13']

    def __init__(self, input_size: Tuple = (224, 224),
                 dataset_root: str = '/mnt/disklcx123/NPY1/',
                 split_dir: str = '/home/wsy/TAWB/dataset/CTA-Set/'):
        """Index every frame listed in the test splits of all devices.

        :param input_size: size handed to ``DataAugmenter``.
        :param dataset_root: root holding one directory per device; sequence
            ids are appended to ``<dataset_root><device_dir>/``.
        :param split_dir: directory containing the ``test_<tag>.npy`` files.
        """
        self.__input_size = input_size
        self._paths_to_seqs = []
        self._nums_to_seqs = []
        self.__da = DataAugmenter(self.__input_size)
        # Devices are appended in fixed order (1..6).
        for device_dir, tag in zip(self._DEVICE_DIRS, self._DEVICE_TAGS):
            paths, nums = self._index_split(split_dir + 'test_' + tag + '.npy',
                                            dataset_root + device_dir + '/')
            self._paths_to_seqs.extend(paths)
            self._nums_to_seqs.extend(nums)

    @staticmethod
    def _index_split(test_path: str, path_to_dataset: str) -> Tuple[list, list]:
        """Expand a split file into one entry per frame.

        The file is read with ``np.load(..., allow_pickle=True).item()`` and must
        provide positionally paired ``'id'`` and ``'num'`` sequences. A sequence
        ``id`` with ``num`` frames yields entries
        ``"<path_to_dataset><id>,1"`` .. ``"<path_to_dataset><id>,<num>"``.

        :returns: ``(entries, frame_counts)`` with one item per frame.
        """
        test_info = np.load(test_path, allow_pickle=True).item()
        paths, nums = [], []
        for seq_id, num in zip(test_info['id'], test_info['num']):
            for frame_id in range(1, num + 1):
                paths.append(path_to_dataset + str(seq_id) + ',' + str(frame_id))
                nums.append(num)
        return paths, nums

    @staticmethod
    def _split_entry(entry: str) -> Tuple[str, int]:
        """Split ``"<sequence_dir>,<frame_id>"`` into ``(sequence_dir, frame_id)``."""
        # rsplit so a comma inside the directory path does not break parsing
        path_to_frame, frame_id = entry.rsplit(',', 1)
        return path_to_frame, int(frame_id)

    @staticmethod
    def _window_ids(frame_id: int, num_frames: int) -> Tuple[int, int, int]:
        """Return the (previous, current, next) frame ids, clamped to ``[1, num_frames]``."""
        return max(frame_id - 1, 1), frame_id, min(frame_id + 1, num_frames)

    def __getitem__(self, index: int) -> Tuple:
        """Load the three-frame window for entry ``index``.

        :returns: ``(seq, mimic, illuminant, path_to_sequence)``.

            - ``seq`` is the stacked window after
              ``DataAugmenter.resize_sequence``, clipping to [0, 255] and scaling
              by 1/255, then ``brg_to_rgb``, ``gamma_correct`` and ``hwc_chw``.
            - ``mimic`` is ``DataAugmenter.augment_mimic`` of the raw window with
              axes permuted ``(0, 3, 1, 2)``.
            - ``illuminant`` is the current frame's entry from ``illu_mat.npy``.
            - ``path_to_sequence`` is the raw entry string.

            All arrays are returned as float32-derived torch tensors.
        """
        path_to_sequence = self._paths_to_seqs[index]
        path_to_frame, frame_id = self._split_entry(path_to_sequence)
        window = self._window_ids(frame_id, self._nums_to_seqs[index])

        illums = np.load(path_to_frame + '/illu_mat.npy', allow_pickle=True).item()
        images = [np.array(np.load(path_to_frame + '/' + str(i) + '.dng.npy'), dtype='float32')
                  for i in window]
        # Stack the frames along a new leading axis.
        seq = np.array(images, dtype='float32')
        illuminant = np.array(illums[str(frame_id)], dtype='float32')

        # mimic is derived from the raw, not-yet-resized window.
        mimic = torch.from_numpy(self.__da.augment_mimic(seq).transpose((0, 3, 1, 2)).copy())

        seq = self.__da.resize_sequence(seq)
        seq = np.clip(seq, 0.0, 255.0) * (1.0 / 255)
        seq = hwc_chw(gamma_correct(brg_to_rgb(seq)))

        seq = torch.from_numpy(seq.copy())
        illuminant = torch.from_numpy(illuminant.copy())
        return seq, mimic, illuminant, path_to_sequence

    def __len__(self) -> int:
        """Return the number of indexed frames across all devices."""
        return len(self._paths_to_seqs)

# training_set = TestCTADataset()
# training_set.__getitem__(1)


